package org.example.model;


import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * Runs a linear-regression training loop through the TensorFlow Java API. It then writes the
 * values read from {@code m/read} and {@code b/read} to {@code linear_model/m.txt} and
 * {@code linear_model/b.txt}.
 *
 * <p>NOTE(review): the {@link Graph} created in {@link #main} is empty. Nothing in this file
 * defines the {@code x}/{@code y} feeds, the {@code update} op or the {@code m/read}/{@code b/read}
 * ops that the session refers to. As written, the first {@code run()} fails because those
 * operations cannot be found. The graph presumably has to be built or imported (for example via
 * {@code Graph#importGraphDef}) before the session is used. TODO confirm.
 *
 * @author liyishan
 * @date 2024/3/8 17:41
 */

public class LinearRegressionModel {

    /** Number of times the {@code update} op is executed. */
    private static final int TRAINING_STEPS = 100;

    /** Directory (relative to the working directory) that receives the learned parameters. */
    private static final String OUTPUT_DIR = "linear_model";

    /**
     * Feeds a fixed set of sample points into the {@code update} op {@value #TRAINING_STEPS}
     * times. It then reads {@code m} and {@code b} and writes each one as text into
     * {@value #OUTPUT_DIR}.
     *
     * @param args ignored
     * @throws Exception if TensorFlow rejects an operation, or the output directory or files
     *     cannot be written
     */
    public static void main(String[] args) throws Exception {
        float[] xs = {0, 1, 2, 3, 4, 5};
        float[] ys = {0, 2, 3, 6, 8, 19};

        // Graph and Session both own native resources, so both must be closed.
        try (Graph graph = new Graph();
             Session session = new Session(graph)) {
            for (int step = 0; step < TRAINING_STEPS; step++) {
                try (Tensor x = Tensor.create(xs);
                     Tensor y = Tensor.create(ys)) {
                    // Fetched tensors are native allocations too; release them immediately.
                    for (Tensor out : session.runner().feed("x", x).feed("y", y).fetch("update").run()) {
                        out.close();
                    }
                }
            }

            // Only the final values are persisted, so read the variables once after training.
            float m = readFirstFloat(session, "m/read");
            float b = readFirstFloat(session, "b/read");

            // Files.write does not create parent directories on its own.
            Path outputDir = Paths.get(OUTPUT_DIR);
            Files.createDirectories(outputDir);
            writeValue(outputDir.resolve("m.txt"), m);
            writeValue(outputDir.resolve("b.txt"), b);
        }
    }

    /**
     * Fetches output 0 of {@code opName} and returns its first float element, closing the
     * fetched tensor.
     *
     * @param session session to run
     * @param opName name of the operation whose first output is read
     * @return the first element of the fetched tensor
     */
    private static float readFirstFloat(Session session, String opName) {
        // NOTE(review): assumes the fetched tensor is float with shape [1], because copyTo needs
        // the destination to match the tensor's shape exactly. TODO confirm against the graph.
        float[] result = new float[1];
        try (Tensor value = session.runner().fetch(opName).run().get(0)) {
            value.copyTo(result);
        }
        return result[0];
    }

    /**
     * Writes {@code value} as its {@link String#valueOf(float)} text to {@code file}, encoded as
     * UTF-8. Any existing content is replaced.
     *
     * @param file destination file
     * @param value value to write
     * @throws IOException if the file cannot be written
     */
    private static void writeValue(Path file, float value) throws IOException {
        Files.write(file, String.valueOf(value).getBytes(StandardCharsets.UTF_8));
    }
}
